跳到主要内容

使用 AST 自动生成 AOP

利用 go/ast 语法树做代码生成

Go 没法像 Java 那样做动态 AOP,但可以通过 go/ast 做代码生成,达成同样目标,而且不像 reflect 会影响性能和静态检查。用的好的话可以极大提高效率,更加自动化,减少手工复粘,也就降低犯错概率。

完整代码:

https://gist.github.com/alsritter/de97e60119daf1a9aac933bef132ca30

需求概述

go.uber.org/zap 日志包性能很好,但是用起来很不方便,虽然新版本添加了 global 方法,但仍然别扭:zap.S().Info()

现在我们的需求就是将 zap 的 sugaredLogger 封装成一个包,让它像 logrus 一样易用,直接调用包内函数:log.Info()

首先安装 ZAP

go get -u go.uber.org/zap

实现原理

我们只需要找到 SugaredLogger 这个 type 拥有的 Exported 方法,将其改为函数,函数体调用其同名方法:

func Info(args ...interface{}) {
zap.S().Debug(args...)
}

获取 ast 语法树

方法可能分散在包内不同 go 文件,所以必须解析整个包,而不是单个文件。

首先要找到 go.uber.org/zap 的源码路径,这里我们极客到底,通过 go/build 包获取其在 gomod 中的路径,不用手动填写:

// getImportPkg 通过 go/build 包获取其在 gomod 中的路径,不用手动填写:
func getImportPkg(pkg string) (string, error) {
p, err := build.Import(pkg, "", build.FindOnly)
if err != nil {
return "", err
}

return p.Dir, err
}

使用测试:

dir, err := getImportPkg("go.uber.org/zap")
if err != nil {
return errx.WithStack(err)
}

log.Printf("dir: %+v", dir)

输出:

2021/12/13 00:11:45 dir: C:\Users\alsritter\go\pkg\mod\go.uber.org\zap@v1.19.1

解析整个 zap 目录,找到对应包名的的 *ast.Package

// parseDir 解析整个 zap 目录,找到对应包名的的 *ast.Package
func parseDir(dir, pkgName string) (*ast.Package, error) {
// 返回的是一个 package name -> package 的 Map
pkgMap, err := parser.ParseDir(
token.NewFileSet(),
dir,
func(info os.FileInfo) bool {
// skip go-test
return !strings.Contains(info.Name(), "_test.go")
},
parser.Mode(0), // no comment
)
if err != nil {
return nil, errx.WithStack(err)
}
pkg, ok := pkgMap[pkgName]
if !ok {
err := errors.New("not found")
return nil, errx.WithStack(err)
}

return pkg, nil
}

使用测试:

// dir: C:\Users\alsritter\go\pkg\mod\go.uber.org\zap@v1.19.1
pkg, err := parseDir(dir, "zap")
if err != nil {
return errx.WithStack(err)
}

找到 SugaredLogger 的方法

遍历 ast,找到 SugaredLogger 的所有 Exported 方法:

// Visit 遍历 ast,找到 SugaredLogger 的所有 Exported 方法(就是公有方法):
func (v *visitor) Visit(node ast.Node) ast.Visitor {
switch n := node.(type) {
// 只需处理 FuncDecl
case *ast.FuncDecl:
if n.Recv == nil ||
!n.Name.IsExported() || // 判断方法名称是否为大写(公开
len(n.Recv.List) != 1 { // 取得当前接收者下有多少个方法
return nil
}
// 判断方法类型,如果不是指针方法类型则失败
t, ok := n.Recv.List[0].Type.(*ast.StarExpr)
if !ok {
return nil
}

if t.X.(*ast.Ident).String() != "SugaredLogger" {
return nil
}

log.Printf("func name: %s", n.Name.String())
v.funcs = append(v.funcs, rewriteFunc(n))
}
return v
}

修改方法的 ast

  • 将方法 Recv 置空,变为函数。
  • 参数名不变,如果为可变参数,则加上展开符 ...。
  • 函数 body 改为调用全局变量 _globalS 的同名方法。
  • 如果有返回值则需要 return 语句。
// rewriteFunc 修改函数的属性
func rewriteFunc(fn *ast.FuncDecl) *ast.FuncDecl {
fn.Recv = nil // 将方法接收者置空,变为函数。
fnName := fn.Name.String()
var args []string
// fn.Type 表示当前当前函数的属性(函数位置、函数的参数列表、函数的返回值)
for _, field := range fn.Type.Params.List {
// 因为 field 可以代表 struct 类型、interface 里面的 method 列表、或者一个参数签名,因此这里的 Names 是一个 List
for _, id := range field.Names {
// 取得参数名
idStr := id.String()
_, ok := field.Type.(*ast.Ellipsis) // 判断当前 field 是否为可变参数
if ok {
// Ellipsis args
idStr += "..."
}
args = append(args, idStr)
}
}

// 函数 body 改为调用 zap.S() 方法。(避免包冲突,这里使用别名)
exprStr := fmt.Sprintf(`zap02.S().%s(%s)`, fnName, strings.Join(args, ","))
expr, err := parser.ParseExpr(exprStr) // 生成这个表达式(方法)的 AST 树
if err != nil {
panic(err)
}
var body []ast.Stmt // 所有 AST Node 都实现了 Stmt
if fn.Type.Results != nil {
body = []ast.Stmt{
// 如果有返回值则需要 return 语句。
&ast.ReturnStmt{
// Return:
Results: []ast.Expr{expr}, // 如果有返回值则需要 return 语句。
},
}
} else {
body = []ast.Stmt{
&ast.ExprStmt{
X: expr,
},
}
}

fn.Body.List = body
return fn
}

ast 转化为 go 代码

单个 func 的 ast 转化为 go 代码,使用 go/format 包:

// ast 转化为 go 代码
func astToGo(dst *bytes.Buffer, node interface{}) error {
addNewline := func() {
err := dst.WriteByte('\n') // add newline
if err != nil {
log.Panicln(err)
}
}

addNewline()
// 单个 func 的 ast 转化为 go 代码,使用 go/format 包:
err := format.Node(dst, token.NewFileSet(), node)
if err != nil {
return err
}
addNewline()
return nil
}

拼装成完整 go file:

// writeGoFile 拼装成完整 go file
func writeGoFile(wr io.Writer, funcs []ast.Decl) error {
// 输出Go代码
header := `// Code generated by log-gen. DO NOT EDIT.
package log
import zap02 "go.uber.org/zap"
`
buffer := bytes.NewBufferString(header)
for _, fn := range funcs {
err := astToGo(buffer, fn)
if err != nil {
return errx.WithStack(err)
}
}

_, err := wr.Write(buffer.Bytes())
return err
}

使用测试

这里的 Walk 方法会按照深度优先搜索方法(depth-first order)遍历整个语法树,我们只需按照我们的业务需要,实现 Visitor 接口即可。 Walk 每遍历一个节点就会调用 Visitor.Visit 方法,传入当前节点。如果 Visit 返回 nil,则停止遍历当前节点的子节点。(具体看遍历那节)

func main() {
err := run()
if err != nil {
log.Fatalln(err)
}
}

func run() error {
dir, err := getImportPkg("go.uber.org/zap")
if err != nil {
return errx.WithStack(err)
}

log.Printf("dir: %+v", dir)

pkg, err := parseDir(dir, "zap")
if err != nil {
return errx.WithStack(err)
}

funcs, err := walkAst(pkg)
if err != nil {
return errx.WithStack(err)
}

err = writeGoFile(os.Stdout, funcs)
if err != nil {
return errx.WithStack(err)
}

return nil
}


func walkAst(node ast.Node) ([]ast.Decl, error) {
v := &visitor{}
ast.Walk(v, node)

log.Printf("funcs len: %d", len(v.funcs))

var decls []ast.Decl
for _, v := range v.funcs {
decls = append(decls, v)
}

return decls, nil
}

这个程序是输出到了 os.Stdout,通过 go:generate 将其重定向到 zap_sugar_generated.go 文件中:

//go:generate sh -c "go run ./generator >zap_sugar_generated.go"

输出:

// Code generated by log-gen. DO NOT EDIT.
package log

import zap02 "go.uber.org/zap"

func Desugar() *Logger {
return zap02.S().Desugar()
}

func Named(name string) *SugaredLogger {
return zap02.S().Named(name)
}

func With(args ...interface{}) *SugaredLogger {
return zap02.S().With(args...)
}

func Debug(args ...interface{}) {
zap02.S().Debug(args...)
}

func Info(args ...interface{}) {
zap02.S().Info(args...)
}

func Warn(args ...interface{}) {
zap02.S().Warn(args...)
}

// ......

这里的返回值可以通过别名来解决找不到目标的情况

// alias
type (
Logger = zap.Logger
SugaredLogger = zap.SugaredLogger
)

References